{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Calculate features for nuclei and generate .pt files for each graph\n",
    "## 15 features in total"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "###TCGA-27-2528-01Z-00-DX1.8160EE21-D3C3-4FFF-8BB6-979235C1F963_2\n",
    "# val:ZT76_39_B_1_5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re, os\n",
    "import cv2\n",
    "import math\n",
    "import random\n",
    "import torch\n",
    "import resnet\n",
    "import skimage.feature\n",
    "import pdb\n",
    "from PIL import Image\n",
    "from pyflann import *\n",
    "from torch_geometric.data import Data\n",
    "from collections import OrderedDict\n",
    "\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torchvision.transforms.functional as F\n",
    "import torch_geometric.data as data\n",
    "import torch_geometric.utils as utils\n",
    "import pdb\n",
    "import torch_geometric"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from model import CPC_model\n",
    "device = torch.device('cuda:{}'.format('0'))\n",
    "model = CPC_model(1024, 256)\n",
    "encoder = model.encoder.to(device)\n",
    "ckpt_dir = './pretrained_models/cpc_best.pt'\n",
    "ckpt = torch.load(ckpt_dir)\n",
    "encoder.load_state_dict(ckpt['encoder_state_dict'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ResNet(\n",
       "  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
       "  (relu): ReLU(inplace=True)\n",
       "  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
       "  (layer1): Sequential(\n",
       "    (0): Bottleneck(\n",
       "      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (relu): ReLU(inplace=True)\n",
       "      (downsample): Sequential(\n",
       "        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      )\n",
       "    )\n",
       "    (1): Bottleneck(\n",
       "      (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (relu): ReLU(inplace=True)\n",
       "    )\n",
       "    (2): Bottleneck(\n",
       "      (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (relu): ReLU(inplace=True)\n",
       "    )\n",
       "  )\n",
       "  (layer2): Sequential(\n",
       "    (0): Bottleneck(\n",
       "      (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (relu): ReLU(inplace=True)\n",
       "      (downsample): Sequential(\n",
       "        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "      )\n",
       "    )\n",
       "    (1): Bottleneck(\n",
       "      (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (relu): ReLU(inplace=True)\n",
       "    )\n",
       "    (2): Bottleneck(\n",
       "      (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (relu): ReLU(inplace=True)\n",
       "    )\n",
       "    (3): Bottleneck(\n",
       "      (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (relu): ReLU(inplace=True)\n",
       "    )\n",
       "  )\n",
       "  (layer3): Sequential(\n",
       "    (0): Bottleneck(\n",
       "      (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (relu): ReLU(inplace=True)\n",
       "      (downsample): Sequential(\n",
       "        (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "      )\n",
       "    )\n",
       "    (1): Bottleneck(\n",
       "      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (relu): ReLU(inplace=True)\n",
       "    )\n",
       "    (2): Bottleneck(\n",
       "      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (relu): ReLU(inplace=True)\n",
       "    )\n",
       "    (3): Bottleneck(\n",
       "      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (relu): ReLU(inplace=True)\n",
       "    )\n",
       "    (4): Bottleneck(\n",
       "      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (relu): ReLU(inplace=True)\n",
       "    )\n",
       "    (5): Bottleneck(\n",
       "      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (relu): ReLU(inplace=True)\n",
       "    )\n",
       "  )\n",
       "  (avgpool): AdaptiveAvgPool2d(output_size=1)\n",
       ")"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "encoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = torch.randn((1, 3, 64, 64)).cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def from_networkx(G):\n",
    "    r\"\"\"Converts a :obj:`networkx.Graph` or :obj:`networkx.DiGraph` to a\n",
    "    :class:`torch_geometric.data.Data` instance.\n",
    "    Args:\n",
    "        G (networkx.Graph or networkx.DiGraph): A networkx graph.\n",
    "    \"\"\"\n",
    "\n",
    "    G = G.to_directed() if not nx.is_directed(G) else G\n",
    "    edge_index = torch.tensor(list(G.edges)).t().contiguous()\n",
    "\n",
    "    keys = []\n",
    "    keys += list(list(G.nodes(data=True))[0][1].keys())\n",
    "    keys += list(list(G.edges(data=True))[0][2].keys())\n",
    "    data = {key: [] for key in keys}\n",
    "\n",
    "    for _, feat_dict in G.nodes(data=True):\n",
    "        for key, value in feat_dict.items():\n",
    "            data[key].append(value)\n",
    "\n",
    "    for _, _, feat_dict in G.edges(data=True):\n",
    "        for key, value in feat_dict.items():\n",
    "            data[key].append(value)\n",
    "\n",
    "    for key, item in data.items():\n",
    "        data[key] = torch.tensor(item)\n",
    "\n",
    "    data['edge_index'] = edge_index\n",
    "    data = torch_geometric.data.Data.from_dict(data)\n",
    "    data.num_nodes = G.number_of_nodes()\n",
    "\n",
    "    return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_cell_image_og(img, cx, cy):\n",
    "    if cx < 32 and cy < 32:\n",
    "        return img[0: cy+32, 0:cx+32, :]\n",
    "    elif cx < 32:\n",
    "        return img[cy-32: cy+32, 0:cx+32, :] \n",
    "    elif cy < 32:\n",
    "        return img[0: cy+32, cx-32:cx+32, :]\n",
    "    else:\n",
    "        return img[cy-32: cy+32, cx-32:cx+32, :]\n",
    "    \n",
    "def my_transform(img):\n",
    "    img = F.to_tensor(img)\n",
    "    img = F.normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n",
    "    return img\n",
    "\n",
    "def get_cpc_features_og(cell):\n",
    "    \n",
    "    cell_R = np.squeeze(cell[:, :, 2:3])\n",
    "    cell_G = np.squeeze(cell[:, :, 1:2])\n",
    "    cell_B = np.squeeze(cell[:, :, 0:1])\n",
    "\n",
    "    cell_R = np.pad(cell_R, [(0, 64-cell_R.shape[0]), (0, 64-cell_R.shape[1])], mode = 'constant')\n",
    "    cell_G = np.pad(cell_G, [(0, 64-cell_G.shape[0]), (0, 64-cell_G.shape[1])], mode = 'constant')\n",
    "    cell_B = np.pad(cell_B, [(0, 64-cell_B.shape[0]), (0, 64-cell_B.shape[1])], mode = 'constant')\n",
    "    cell = np.stack((cell_R, cell_B, cell_G))\n",
    "    \n",
    "    cell = np.transpose(cell, (1, 2, 0))\n",
    "    cell = my_transform(cell)\n",
    "    cell = cell.unsqueeze(0)\n",
    "    \n",
    "    device = torch.device('cuda:{}'.format('0'))\n",
    "    \n",
    "    feats = encoder(cell.to(device)).cpu().detach().numpy()\n",
    "    return feats\n",
    "    #feats_cpu = [f.cpu().detach().numpy() for f in feats]\n",
    "    #return feats_cpu"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision import transforms\n",
    "import itertools\n",
    "\n",
    "def get_cell_image(img, cx, cy, size=512):\n",
    "    cx = 32 if cx < 32 else size-32 if cx > size-32 else cx\n",
    "    cy = 32 if cy < 32 else size-32 if cy > size-32 else cy\n",
    "    return img[cy-32:cy+32, cx-32:cx+32, :]\n",
    "\n",
    "def get_cpc_features(cell):\n",
    "    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n",
    "    cell = transform(cell)\n",
    "    cell = cell.unsqueeze(0)\n",
    "    device = torch.device('cuda:{}'.format('0'))\n",
    "    feats = encoder(cell.to(device)).cpu().detach().numpy()[0]\n",
    "    return feats\n",
    "\n",
    "def get_cell_features(img, contour):\n",
    "    \n",
    "    # Get contour coordinates from contour\n",
    "    (cx, cy), (short_axis, long_axis), angle = cv2.fitEllipse(contour)\n",
    "    cx, cy = int(cx), int(cy)\n",
    "    \n",
    "    # Get a 64 x 64 center crop over each cell    \n",
    "    img_cell = get_cell_image(img, cx, cy)\n",
    "\n",
    "    grey_region = cv2.cvtColor(img_cell, cv2.COLOR_RGB2GRAY)\n",
    "    img_cell_grey = np.pad(grey_region, [(0, 64-grey_region.shape[0]), (0, 64-grey_region.shape[1])], mode = 'reflect') \n",
    "\n",
    "\n",
    "    # 1. Generating contour features\n",
    "    eccentricity = math.sqrt(1-(short_axis/long_axis)**2)\n",
    "    convex_hull = cv2.convexHull(contour)\n",
    "    area, hull_area = cv2.contourArea(contour), cv2.contourArea(convex_hull)\n",
    "    solidity = float(area)/hull_area\n",
    "    arc_length = cv2.arcLength(contour, True)\n",
    "    roundness = (arc_length/(2*math.pi))/(math.sqrt(area/math.pi))\n",
    "    \n",
    "    # 2. Generating GLCM features\n",
    "    out_matrix = skimage.feature.greycomatrix(img_cell_grey, [1], [0])\n",
    "    dissimilarity = skimage.feature.greycoprops(out_matrix, 'dissimilarity')[0][0]\n",
    "    homogeneity = skimage.feature.greycoprops(out_matrix, 'homogeneity')[0][0]\n",
    "    energy = skimage.feature.greycoprops(out_matrix, 'energy')[0][0]\n",
    "    ASM = skimage.feature.greycoprops(out_matrix, 'ASM')[0][0]\n",
    "    \n",
    "    # 3. Generating CPC features\n",
    "    cpc_feats = get_cpc_features(img_cell)\n",
    "    \n",
    "\n",
    "    # Concatenate + Return all features\n",
    "    x = [[short_axis, long_axis, angle, area, arc_length, eccentricity, roundness, solidity],\n",
    "         [dissimilarity, homogeneity, energy, ASM], \n",
    "         cpc_feats]\n",
    "    \n",
    "    return np.array(list(itertools.chain(*x)), dtype=np.float64), cx, cy\n",
    "\n",
    "\n",
    "def seg2graph(img, contours):\n",
    "    G = nx.Graph()\n",
    "    \n",
    "    contours = [c for c in contours if c.shape[0] > 5]\n",
    "\n",
    "    for v, contour in enumerate(contours):\n",
    "\n",
    "        features, cx, cy = get_cell_features(img, contour)\n",
    "        G.add_node(v, centroid = [cx, cy], x = features)\n",
    "\n",
    "    if v < 5: return None\n",
    "    return G"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_dir = \"/media/hdd1/Jingwen/codes/staintools/\"\n",
    "save_dir = os.path.join(data_dir, 'KIRC_st_cpc_blue')\n",
    "img_dir = os.path.join(data_dir, 'KIRC_st')\n",
    "seg_dir =  os.path.join(data_dir,'KIRC_st_seg')\n",
    "\n",
    "roi1 = 'TCGA-B0-4839-01Z-00-DX1.0c13b082-d7e6-4327-b5cc-ab6bd99b78aa_roi_2_x_74016_y_24064_99.388.png'\n",
    "roi2 = 'TCGA-B0-4694-01Z-00-DX1.beeee877-b7f3-4110-84e9-2db8be5667f5_roi_1_x_86528_y_47936_95.186.png'\n",
    "assert roi1 in os.listdir(seg_dir)\n",
    "assert roi2 in os.listdir(seg_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_dir = \"/media/hdd1/Jingwen/codes/staintools/\"\n",
    "save_dir = os.path.join(data_dir, 'all_st_cpc_img_blue')\n",
    "img_dir = os.path.join(data_dir, 'all_st')\n",
    "seg_dir =  os.path.join(data_dir,'all_st_seg')\n",
    "\n",
    "roi1 = 'TCGA-06-0174-01Z-00-DX3.23b6e12e-dfc1-4c6f-903e-170038a0e055_1.png'\n",
    "roi2 = 'TCGA-HT-7470-01Z-00-DX4.204D0CF2-A22E-4428-8E8C-572432B86280_1.png'\n",
    "roi3 = 'TCGA-26-1442-01Z-00-DX1.FD8D4EB7-AD5E-49E8-BD0B-6CDDEA8DDB35_1.png'\n",
    "\n",
    "assert roi1 in os.listdir(seg_dir)\n",
    "assert roi2 in os.listdir(seg_dir)\n",
    "assert roi3 in os.listdir(seg_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "G = seg2graph(img, contours)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([], shape=(64, 0, 3), dtype=uint8)"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "G"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 11961/11961 [01:25<00:00, 140.49it/s]\n"
     ]
    }
   ],
   "source": [
    "img_dir = os.path.join(data_dir, 'KIRC_st')\n",
    "for img_fname in tqdm(os.listdir(seg_dir)):\n",
    "    if int(img_fname.split('_')[2]) > 2: continue\n",
    "    os.system('cp %s %s' % (os.path.join(data_dir, img_dir, img_fname), os.path.join(data_dir, 'KIRC_st_small', img_fname)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:13<00:00,  4.41s/it]\n"
     ]
    }
   ],
   "source": [
    "\n",
    "pt_dir = os.path.join(save_dir, 'pt_bi')\n",
    "graph_dir = os.path.join(save_dir, 'graphs')\n",
    "fail_list = []\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "for img_fname in tqdm([roi1, roi2, roi3]):\n",
    "    \n",
    "    #if int(img_fname.split('_')[2]) > 2: continue\n",
    "    #print(\"Processing...(%d/%d):\\t%s\" % (idx+1, len(os.listdir(seg_dir)), img_fname))\n",
    "    \n",
    "    img = np.array(Image.open(os.path.join(img_dir, img_fname)))\n",
    "    seg = np.array(Image.open(os.path.join(seg_dir, img_fname)))\n",
    "    ret, binary = cv2.threshold(seg, 127, 255, cv2.THRESH_BINARY) \n",
    "    contours, hierarchy = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)\n",
    "    if len(contours) < 1: continue\n",
    "    \n",
    "    G = seg2graph(img, contours)\n",
    "\n",
    "    if G is None: \n",
    "        fail_list.append(img_fname)\n",
    "        continue\n",
    "\n",
    "\n",
    "    centroids = []\n",
    "    for u, attrib in G.nodes(data=True):\n",
    "        centroids.append(attrib['centroid'])\n",
    "    \n",
    "    cell_centroids = np.array(centroids).astype(np.float64)\n",
    "    dataset = cell_centroids\n",
    "    \n",
    "    start = None\n",
    "            \n",
    "    for idx, attrib in list(G.nodes(data=True)):\n",
    "        start = idx\n",
    "        flann = FLANN()\n",
    "        testset = np.array([attrib['centroid']]).astype(np.float64)\n",
    "        results, dists = flann.nn(dataset, testset, num_neighbors=5, algorithm = 'kmeans', branching = 32, iterations = 100, checks = 16)\n",
    "        results, dists = results[0], dists[0]\n",
    "        nns_fin = []\n",
    "       # assert (results.shape[0] < 6)\n",
    "        for i in range(1, len(results)):\n",
    "            G.add_edge(idx, results[i], weight = dists[i])\n",
    "            nns_fin.append(results[i])\n",
    "        #attrib['nn'] = list(nns_fin)\n",
    "\n",
    "    G = G.subgraph(max(nx.connected_components(G), key=len))\n",
    "\n",
    "    #for idx, attrib in list(G.nodes(data=True)):\n",
    "    #    cv2.circle(img, tuple(attrib['centroid']), 8, (0, 255, 0), -1)\n",
    "    \n",
    "    cv2.drawContours(img, contours, -1, (0,255,0), 2)\n",
    "    \n",
    "    for n, nbrs in G.adjacency():\n",
    "        for nbr, eattr in nbrs.items():\n",
    "            cv2.line(img, tuple(G.nodes[n]['centroid']),  tuple(G.nodes[nbr]['centroid']), (0, 0, 255), 2)\n",
    "\n",
    "    Image.fromarray(img).save(os.path.join(graph_dir, img_fname))\n",
    "    \n",
    "    G = from_networkx(G)\n",
    "    \n",
    "    edge_attr_long = (G.weight.unsqueeze(1)).type(torch.LongTensor)\n",
    "    G.edge_attr = edge_attr_long \n",
    "    \n",
    "    edge_index_long = G['edge_index'].type(torch.LongTensor)\n",
    "    G.edge_index = edge_index_long\n",
    "    \n",
    "    x_float = G['x'].type(torch.FloatTensor)\n",
    "    G.x = x_float\n",
    "    \n",
    "    G['weight'] = None\n",
    "    G['nn'] = None\n",
    "    torch.save(G, os.path.join(pt_dir, img_fname[:-4]+'.pt'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "for img_fname in tqdm(os.listdir(seg_dir)):\n",
    "    \n",
    "    if int(img_fname.split('_')[2]) > 2: continue\n",
    " "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = torch.load('../staintools/KIRC_st_cpc/pt_bi/TCGA-3Z-A93Z-01Z-00-DX1.79F4D1A6-ACDB-4AB1-B8A8-C1CEE617C734_roi_0_x_70176_y_35424_99.659.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Data(centroid=[102, 2], edge_attr=[502, 1], edge_index=[2, 502], x=[102, 1036])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 262,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{0,\n",
       "  1,\n",
       "  2,\n",
       "  3,\n",
       "  4,\n",
       "  5,\n",
       "  6,\n",
       "  7,\n",
       "  8,\n",
       "  9,\n",
       "  10,\n",
       "  11,\n",
       "  12,\n",
       "  13,\n",
       "  14,\n",
       "  15,\n",
       "  16,\n",
       "  17,\n",
       "  18,\n",
       "  19,\n",
       "  20,\n",
       "  21,\n",
       "  22,\n",
       "  23,\n",
       "  24,\n",
       "  25,\n",
       "  26,\n",
       "  27,\n",
       "  28,\n",
       "  29,\n",
       "  30,\n",
       "  31,\n",
       "  32,\n",
       "  33,\n",
       "  34,\n",
       "  35,\n",
       "  36,\n",
       "  37,\n",
       "  38,\n",
       "  39,\n",
       "  40,\n",
       "  41,\n",
       "  42,\n",
       "  43,\n",
       "  44,\n",
       "  45,\n",
       "  46,\n",
       "  47,\n",
       "  48,\n",
       "  49,\n",
       "  50,\n",
       "  51,\n",
       "  52,\n",
       "  53,\n",
       "  54,\n",
       "  55,\n",
       "  56,\n",
       "  57,\n",
       "  58,\n",
       "  59,\n",
       "  60,\n",
       "  61,\n",
       "  62,\n",
       "  63,\n",
       "  64,\n",
       "  65,\n",
       "  66,\n",
       "  67,\n",
       "  68,\n",
       "  69,\n",
       "  70,\n",
       "  71,\n",
       "  72,\n",
       "  73,\n",
       "  74,\n",
       "  75,\n",
       "  76,\n",
       "  77,\n",
       "  78,\n",
       "  79,\n",
       "  80,\n",
       "  81,\n",
       "  82,\n",
       "  83,\n",
       "  84,\n",
       "  85,\n",
       "  86,\n",
       "  87,\n",
       "  88,\n",
       "  89,\n",
       "  90,\n",
       "  91,\n",
       "  92,\n",
       "  93,\n",
       "  94,\n",
       "  95,\n",
       "  96,\n",
       "  97,\n",
       "  98,\n",
       "  99,\n",
       "  100,\n",
       "  101,\n",
       "  102,\n",
       "  103,\n",
       "  104,\n",
       "  105,\n",
       "  106,\n",
       "  107,\n",
       "  108,\n",
       "  109,\n",
       "  110,\n",
       "  111,\n",
       "  112,\n",
       "  113,\n",
       "  114,\n",
       "  115,\n",
       "  116,\n",
       "  117,\n",
       "  118,\n",
       "  119,\n",
       "  120,\n",
       "  121,\n",
       "  122,\n",
       "  123,\n",
       "  124,\n",
       "  125,\n",
       "  126,\n",
       "  127,\n",
       "  128,\n",
       "  129,\n",
       "  130,\n",
       "  131,\n",
       "  132,\n",
       "  133}]"
      ]
     },
     "execution_count": 262,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "list(nx.connected_components(G))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 264,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<networkx.classes.graph.Graph at 0x7fb3167cdc90>"
      ]
     },
     "execution_count": 264,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "G"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 263,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<networkx.classes.graph.Graph at 0x7fb2c0301950>"
      ]
     },
     "execution_count": 263,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 255,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{0,\n",
       " 1,\n",
       " 2,\n",
       " 3,\n",
       " 4,\n",
       " 5,\n",
       " 6,\n",
       " 7,\n",
       " 8,\n",
       " 9,\n",
       " 10,\n",
       " 11,\n",
       " 12,\n",
       " 13,\n",
       " 14,\n",
       " 15,\n",
       " 16,\n",
       " 17,\n",
       " 18,\n",
       " 19,\n",
       " 20,\n",
       " 21,\n",
       " 22,\n",
       " 23,\n",
       " 24,\n",
       " 25,\n",
       " 26,\n",
       " 27,\n",
       " 28,\n",
       " 29,\n",
       " 30,\n",
       " 31,\n",
       " 32,\n",
       " 33,\n",
       " 34,\n",
       " 35,\n",
       " 36,\n",
       " 37,\n",
       " 38,\n",
       " 39,\n",
       " 40,\n",
       " 41,\n",
       " 42,\n",
       " 43,\n",
       " 44,\n",
       " 45,\n",
       " 46,\n",
       " 47,\n",
       " 48,\n",
       " 49,\n",
       " 50,\n",
       " 51,\n",
       " 52,\n",
       " 53,\n",
       " 54,\n",
       " 55,\n",
       " 56,\n",
       " 57,\n",
       " 58,\n",
       " 59,\n",
       " 60,\n",
       " 61,\n",
       " 62,\n",
       " 63,\n",
       " 64,\n",
       " 65,\n",
       " 66,\n",
       " 67,\n",
       " 68,\n",
       " 69,\n",
       " 70,\n",
       " 71,\n",
       " 72,\n",
       " 73,\n",
       " 74,\n",
       " 75,\n",
       " 76,\n",
       " 77,\n",
       " 78,\n",
       " 79,\n",
       " 80,\n",
       " 81,\n",
       " 82,\n",
       " 83,\n",
       " 84,\n",
       " 85,\n",
       " 86,\n",
       " 87,\n",
       " 88,\n",
       " 89,\n",
       " 90,\n",
       " 91,\n",
       " 92,\n",
       " 93,\n",
       " 94,\n",
       " 95,\n",
       " 96,\n",
       " 97,\n",
       " 98,\n",
       " 99,\n",
       " 100,\n",
       " 101,\n",
       " 102,\n",
       " 103,\n",
       " 104,\n",
       " 105,\n",
       " 106,\n",
       " 107,\n",
       " 108,\n",
       " 109,\n",
       " 110,\n",
       " 111,\n",
       " 112,\n",
       " 113,\n",
       " 114,\n",
       " 115,\n",
       " 116,\n",
       " 117,\n",
       " 118,\n",
       " 119,\n",
       " 120,\n",
       " 121,\n",
       " 122,\n",
       " 123,\n",
       " 124,\n",
       " 125,\n",
       " 126,\n",
       " 127,\n",
       " 128,\n",
       " 129,\n",
       " 130,\n",
       " 131,\n",
       " 132,\n",
       " 133}"
      ]
     },
     "execution_count": 255,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "G"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 224,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 0, 17, 27, 13, 28], dtype=int32)"
      ]
     },
     "execution_count": 224,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 231,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(134, 2)"
      ]
     },
     "execution_count": 231,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 200,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "13.0"
      ]
     },
     "execution_count": 200,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.linalg.norm(results[3]-results[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 205,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([   0., 2548., 5777., 6698., 7001.])"
      ]
     },
     "execution_count": 205,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dists"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 190,
   "metadata": {},
   "outputs": [
    {
     "ename": "AttributeError",
     "evalue": "'Data' object has no attribute 'node'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-190-74f3b11c05eb>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mG\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m: 'Data' object has no attribute 'node'"
     ]
    }
   ],
   "source": [
    "G.node(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results[]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "metadata": {},
   "outputs": [],
   "source": [
    "idx, attrib = list(G.nodes(data=True))[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 17., 487.]])"
      ]
     },
     "execution_count": 98,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "testset = np.array([attrib['centroid']]).astype(np.float64)\n",
    "testset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "metadata": {},
   "outputs": [],
   "source": [
    "results, dists = flann.nn(dataset, testset, num_neighbors=5, algorithm = 'kmeans', branching = 32, iterations = 7, checks = 16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 107,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[   0., 2548., 5777., 6698., 7001.]])"
      ]
     },
     "execution_count": 107,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dists"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0, 17, 27, 13, 28]], dtype=int32)"
      ]
     },
     "execution_count": 105,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 2\n"
     ]
    }
   ],
   "source": [
    "print(len(G.nodes[n]), len(G.nodes[nbr]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 87,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(len(G.nodes[n]) > 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 76,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(G.nodes[n]) > 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'centroid': [353, 482],\n",
       " 'x': array([ 11.95111752,  14.6584177 , 150.54612732, ...,   0.        ,\n",
       "          0.        ,   0.        ])}"
      ]
     },
     "execution_count": 66,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "G.nodes[nbr]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
